Skip to content

Conversation

schwarzschild-radius
Copy link
Contributor

This commit adds support for tcgen05.mma instructions in NVPTX with tests under CodeGen/NVPTX/tcgen05-mma*. This tcgen05.mma instructions are modeled as intrinsics with multiple flag arguments to model cta_group, mma kind, collector usage etc. The rationale for the design is documented in NVPTXUsage.rst file. For more details, please refer the PTX ISA

@llvmbot
Copy link
Member

llvmbot commented Aug 4, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-nvptx

Author: Pradeep Kumar (schwarzschild-radius)

Changes

This commit adds support for tcgen05.mma instructions in NVPTX with tests under CodeGen/NVPTX/tcgen05-mma*. This tcgen05.mma instructions are modeled as intrinsics with multiple flag arguments to model cta_group, mma kind, collector usage etc. The rationale for the design is documented in NVPTXUsage.rst file. For more details, please refer the PTX ISA


Patch is 386.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151949.diff

13 Files Affected:

  • (modified) llvm/docs/NVPTXUsage.rst (+385-3)
  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+429-1)
  • (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+14)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+291)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+38-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+478-4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+1-1)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-block-scale-ptx88.ll (+526)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-block-scale.ll (+291)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-disable-output-lane.ll (+677)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-scale-d.ll (+412)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-ws.ll (+569)
  • (added) llvm/test/CodeGen/NVPTX/tcgen05-mma.ll (+601)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index d28eb6860c33a..1b61df2cf5254 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1945,6 +1945,388 @@ The last argument `i1 %unpack` is a compile-time constant which when set, indica
 For more information, refer to the
 `PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
 
+tcgen05.mma Intrinsics
+----------------------
+
+One of the key instructions introduced in the Blackwell architecture is the tcgen05.mma family, which carries out matrix multiply-accumulate operations using the 5th generation Tensor Core unit. The `tcgen05.mma` instruction supports a broad range of capabilities, including sparsity, block scaling, and weight-stationary convolutions. Accurately modeling these through intrinsics is highly complex, and the following table outlines the large number of intrinsics required to fully support the tcgen05.mma instruction set.
+
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| variant                            | Configuration                                                                                     | Total Variants |
++====================================+===================================================================================================+================+
+| tcgen05.mma.shared                 | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.tensor.ashift          | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d                | 2 (space) x 2 (sp) x 2 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d.tensor.ashift  | 2 (sp) x 2 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 16             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane    | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane... | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf4nvf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)               | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)                   | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf8f6f4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)               | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.ws                     | 2 (space) x 2 (sp) x 4 (kind) x 2 (zero_col_mask) x 4 (collector_usage_op) x 4 (collector_buffer) | 256            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| Total                              |                                                                                                   | 816            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+
+To reduce the number of possible intrinsic variations, we've modeled the tcgen05.mma instructions using flag operands. We've added range checks to these flags to prevent invalid values. We also expanded some flags back into intrinsic modifiers to avoid supporting invalid combinations of features.
+
+'``llvm.nvvm.tcgen05.mma.*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; .sp variants
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group, i32 %collector_usage_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; .scale_d variants
+  declare void @llvm.nvvm.tcgen05.mma.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+
+  ; sp.scale_d variants
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, addrspace(6) %atensor, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+
+Overview:
+"""""""""
+
+`nvvm.tcgen05.mma` is an asynchronous intrinsic which initiates an `MxNxK` matrix multiply and accumulate operation, `D = A * B + D` where the `A` matrix is `M x K`, the `B` matrix is `K x N`, and the `D` matrix is `M x N`. The operation of the form `D = A*B` is issued when the input predicate argument `%enable_inp_d` is false. The optional immediate argument `%scale_d_imm` can be specified to scale the input matrix `D` as follows: `D = A * B + D * (2 ^ - %scale_d_imm)`. The valid range of values for argument `%scale_d_imm` is `[0, 15]`. The 32-bit register operand idesc is the instruction descriptor as described in `Instruction descriptor <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor>`__
+
+`nvvm.tcgen05.mma` has single thread semantics, unlike the collective instructions `nvvm.mma.sync` or the PTX `wgmma.mma_async` instruction. So, a single thread issuing the `nvvm.tcgen05.mma` will result in the initiation of the whole matrix multiply and accumulate operation
+
+When `.sp` is specifed, the dimension of A matrix is `M x (K/2)` and requires specifiying an additional `%spmetadata` argument
+
+`.ashift` shifts the rows of the A matrix down by one row, except for the last row in the Tensor Memory. `.ashift` is only allowed with M = 128 or M = 256.
+
+The `%collector_usage_a_op_flag` flag specifies the usage of collector buffer for matrix `A`. It is illegal to specify either of `USE` or `FILL` for `%collector_usage_a_op_flag` along with `.ashift`
+
+For more information, refer to the
+`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__
+
+The following tables describes the possible values of the flag arguments
+
+`%kind_flag` flag:
+
+============= ==========
+  `kind_flag`   value
+============= ==========
+     F16          0
+     TF32         1
+     F8F6F4       2
+     I8           3
+============= ==========
+
+`%cta_group` flag:
+
+============= ==========
+ `cta_group`    value
+============= ==========
+     CG1          1
+     CG2          2
+============= ==========
+
+`%collector_usage_a_op_flag` flag:
+
+============================= ==========
+ `collector_usage_a_op_flag`    value
+============================= ==========
+     DISCARD                      0
+     LASTUSE                      1
+     USE                          2
+     FILL                         3
+============================= ==========
+
+'``llvm.nvvm.tcgen05.mma.block_scale*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  ; mxf8f6f4
+  declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; mxf4
+  declare void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; mxf4nvf4
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+
+Overview:
+"""""""""
+`nvvm.tcgen05.mma.block_scale` is an asynchronous intrinsic which initiates an `MxNxK` matrix multiply and accumulate operation, `D = (A * scale_a)  * (B * scale_a) + D` where the `A` matrix is `M x K`, the `B` matrix is `K x N`, and the `D` matrix is `M x N`. The matrices `A` and `B` are scaled with `%scale_A` and `%scale_B` matrices respectively before performing the matrix multiply and accumulate operation. The operation of the form `D = A*B` is issued when the input predicate argument `%enable_inp_d` is false. The 32-bit register operand idesc is the instruction descriptor as described in `Instruction descriptor <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor>`__
+
+`nvvm.tcgen05.mma.block_scale` has single thread semantics, unlike the collective instructions `nvvm.mma.sync` or the PTX `wgmma.mma_async` instruction. So, a single thread issuing the `nvvm.tcgen05.mma.block_scale` will result in the initiation of the whole matrix multiply and accumulate operation
+
+When `.sp` is specifed, the dimension of A matrix is `Mx(K/2)` and requires specifiying an additional `%spmetadata` argument
+
+The `%collector_usage_a_op_flag` flag specifies the usage of collector buffer for matrix `A`
+
+For more information, refer to the
+`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__
+
+The following tables describes the possible values of the flag arguments
+
+`%kind_flag` flag:
+
+============= ==========
+ `kind_flag`    value
+============= ==========
+  MXF8F6F4        0
+  MXF4            1
+  MXF4NVF4        2
+============= ==========
+
+`%cta_group` flag:
+
+============= ==========
+ `cta_group`    value
+============= ==========
+     CG1          1
+     CG2          2
+============= ==========
+
+`%collector_usage_a_op_flag` flag:
+
+========================...
[truncated]

Copy link

github-actions bot commented Aug 4, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@schwarzschild-radius schwarzschild-radius force-pushed the tcgen05_mma_nvptx_support branch 2 times, most recently from 8dd88b7 to 8d4b4f5 Compare August 4, 2025 14:46
@schwarzschild-radius
Copy link
Contributor Author

@Artem-B @AlexMaclean Can you please help with the review?


`.ashift` shifts the rows of the A matrix down by one row, except for the last row in the Tensor Memory. `.ashift` is only allowed with M = 128 or M = 256.

The `%collector_usage_a_op_flag` flag specifies the usage of collector buffer for matrix `A`. It is illegal to specify either of `USE` or `FILL` for `%collector_usage_a_op_flag` along with `.ashift`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The explanation could be improved. Right now it does not make a lot of sense unless you're already familiar with what it's supposed to do.

E.g. "flag specifies the usage" by itself does not explain much. It would be more helpful to make it more specific. Something along the lines of "if %collector_usage_a_op_flag is set to X, then intrinsic does/generates Y. It's ".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Artem-B @durga4github and I had a discussion and we have added an overview section detailing the mechanics of the tcgen05.mma instruction. Please let us know if that is sufficient

.. code-block:: llvm

; mxf8f6f4
declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't %kind_flag redundant heres, considering that the intrinsic already has mxf8f6f4 in its name?

Applies to other intrinsics below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Artem-B Thanks for catching that! I have remove the %kind_flag from the intrinsic definition as it is not required

llvm_i32_ty, // 3. idesc
llvm_i1_ty, // 4. enable_inp_d
llvm_tmem_ptr_ty, // 5. spmetadata
llvm_v8i32_ty], // 6. disable output lane
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like using illegal types in these intrinsics has forced you to add a lot of additional logic in NVPTX isel lowering. Is there a reason we need to use these types? Could these vectors be flattened into a series of arguments?

This commit adds support for tcgen05.mma instructions in NVPTX which tests under CodeGen/NVPTX/tcgen05-mma*. This tcgen05.mma instructions are modeled as intrinsics with multiple arguments to model cta_group, mma kind, collector usage etc. The rationale for the design is present documented in NVPTXUsage.rst file
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants